import torch.nn.functional as F


def ce_loss(pred, label):
    """Return the cross-entropy loss between ``pred`` and ``label``.

    Thin wrapper around ``torch.nn.functional.cross_entropy`` called with
    its default arguments (per the PyTorch docs, this means ``reduction='mean'``).

    NOTE(review): presumably ``pred`` holds unnormalized logits and ``label``
    holds integer class indices (or class probabilities). Confirm this against
    the callers.
    """
    return F.cross_entropy(pred, label)


def bce_loss(pred, label):
    """Return the binary cross-entropy loss between ``pred`` and ``label``.

    Thin wrapper around ``torch.nn.functional.binary_cross_entropy`` called
    with its default arguments. That function does not apply a sigmoid, so
    ``pred`` is expected to already contain probabilities in ``[0, 1]``.

    NOTE(review): ``label`` presumably has the same shape as ``pred`` and a
    floating-point dtype. Confirm this against the callers.
    """
    return F.binary_cross_entropy(pred, label)


def mse_loss(pred, label):
    """Return the mean-squared-error loss between ``pred`` and ``label``.

    Thin wrapper around ``torch.nn.functional.mse_loss`` called with its
    default arguments (per the PyTorch docs, this means ``reduction='mean'``).
    """
    return F.mse_loss(pred, label)


# Registry that maps a loss name string to its wrapper function defined above.
loss_dict = dict(
    ce_loss=ce_loss,
    bce_loss=bce_loss,
    mse_loss=mse_loss,
)

